"""
TFTModel Sub-Modules
--------------------

Implementation of ``nn.Modules`` for the Temporal Fusion Transformer, adapted from PyTorch-Forecasting:
https://github.com/jdb78/pytorch-forecasting

PyTorch Forecasting v0.9.1 License from https://github.com/jdb78/pytorch-forecasting/blob/master/LICENSE, accessed
on Wed, November 3, 2021:
'The MIT License

Copyright 2020 Jan Beitner

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
'
"""

from typing import Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from darts.logging import get_logger
from darts.utils.torch import MonteCarloDropout

logger = get_logger(__name__)

HiddenState = Union[tuple[torch.Tensor, torch.Tensor], torch.Tensor]


def get_embedding_size(n: int, max_size: int = 100) -> int:
    """
    Determine empirically good embedding sizes (formula taken from fastai).
    Args:
        n (int): number of classes
        max_size (int, optional): maximum embedding size. Defaults to 100.
    Returns:
        int: embedding size
    """
    if n > 2:
        return min(round(1.6 * n**0.56), max_size)
    else:
        return 1
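
# Worked example of the heuristic above: ``get_embedding_size(2)`` returns 1
# (binary categories get a single dimension), ``get_embedding_size(10)`` returns
# round(1.6 * 10 ** 0.56) == 6, and the result is capped at ``max_size`` for
# large cardinalities, e.g. ``get_embedding_size(10_000)`` returns 100.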


class _TimeDistributedEmbeddingBag(nn.EmbeddingBag):
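    """Applies ``nn.EmbeddingBag`` identically across timesteps: inputs with a
    time axis are flattened to ``(samples * timesteps, input_size)``, embedded,
    and reshaped back afterwards."""
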
    def __init__(self, *args, batch_first: bool = False, **kwargs):
        super().__init__(*args, **kwargs)
        self.batch_first = batch_first

    def forward(self, x):  # pyright: ignore[reportIncompatibleMethodOverride]
        if len(x.size()) <= 2:
            return super().forward(x)

        # Squash samples and timesteps into a single axis
        x_reshape = x.contiguous().view(
            -1, x.size(-1)
        )  # (samples * timesteps, input_size)

        y = super().forward(x_reshape)

        # Reshape back to restore the time axis
        if self.batch_first:
            y = y.contiguous().view(
                x.size(0), -1, y.size(-1)
            )  # (samples, timesteps, output_size)
        else:
            y = y.view(-1, x.size(1), y.size(-1))  # (timesteps, samples, output_size)
        return y


class _MultiEmbedding(nn.Module):
    def __init__(
        self,
        embedding_sizes: dict[str, tuple[int, int]],
        variable_names: list[str],
    ):
        """Embedding layer for categorical variables including groups of categorical variables.
        Enabled for static and dynamic categories (i.e. 3 dimensions for batch x time x categories).

        Parameters
        ----------
        embedding_sizes
            dictionary of embedding sizes, e.g. ``{'cat1': (10, 3)}``
            indicates that the first categorical variable has 10 unique values which are mapped to 3 embedding
            dimensions. Use :func:`get_embedding_size` to automatically obtain
            reasonable embedding sizes depending on the number of categories.
        variable_names
            list of categorical variable names to ensure ordered iterations.
        """
        super().__init__()
        self.embedding_sizes = embedding_sizes
        self.variable_names = variable_names

        self.embeddings = nn.ModuleDict({
            name: nn.Embedding(*embedding_sizes[name]) for name in variable_names
        })

    @property
    def input_size(self) -> int:
        return len(self.variable_names)

    @property
    def output_size(self) -> dict[str, int]:
        return {name: sizes[1] for name, sizes in self.embedding_sizes.items()}

    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        """
        Parameters
        ----------
        x
            input tensor of shape batch x (optional) time x categoricals in the order of ``variable_names``.

        Returns
        -------
        dict
            dictionary mapping each variable name to its embedding of shape batch x (optional) time x
            embedding size.
        """
        return {
            name: self.embeddings[name](x[..., i])
            for i, name in enumerate(self.variable_names)
        }
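
# Minimal usage sketch for ``_MultiEmbedding`` (hypothetical sizes, not part of
# the module API):
#   emb = _MultiEmbedding(
#       embedding_sizes={"cat1": (10, 3), "cat2": (4, 2)},
#       variable_names=["cat1", "cat2"],
#   )
#   x = torch.randint(0, 4, (32, 24, 2))  # batch x time x categoricals
#   out = emb(x)  # {"cat1": (32, 24, 3) tensor, "cat2": (32, 24, 2) tensor}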


class _TimeDistributedInterpolation(nn.Module):
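    """Linearly interpolates the last dimension of the input to ``output_size``,
    applied identically across timesteps; with ``trainable=True`` the upsampled
    features are rescaled by a learned sigmoid gate (initialized so features
    pass through unchanged)."""
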
    def __init__(
        self, output_size: int, batch_first: bool = False, trainable: bool = False
    ):
        super().__init__()
        self.output_size = output_size
        self.batch_first = batch_first
        self.trainable = trainable
        if self.trainable:
            self.mask = nn.Parameter(torch.zeros(self.output_size, dtype=torch.float32))
            self.gate = nn.Sigmoid()

    def interpolate(self, x):
        upsampled = F.interpolate(
            x.unsqueeze(1), self.output_size, mode="linear", align_corners=True
        ).squeeze(1)
        if self.trainable:
            upsampled = upsampled * self.gate(self.mask.unsqueeze(0)) * 2.0
        return upsampled

    def forward(self, x):
        if len(x.size()) <= 2:
            return self.interpolate(x)

        # Squash samples and timesteps into a single axis
        x_reshape = x.contiguous().view(
            -1, x.size(-1)
        )  # (samples * timesteps, input_size)

        y = self.interpolate(x_reshape)

        # Reshape back to restore the time axis
        if self.batch_first:
            y = y.contiguous().view(
                x.size(0), -1, y.size(-1)
            )  # (samples, timesteps, output_size)
        else:
            y = y.view(-1, x.size(1), y.size(-1))  # (timesteps, samples, output_size)

        return y


class _GatedLinearUnit(nn.Module):
    """Gated Linear Unit"""

    def __init__(
        self,
        input_size: int,
        hidden_size: Optional[int] = None,
        dropout: Optional[float] = None,
    ):
        super().__init__()

        if dropout is not None:
            self.dropout = MonteCarloDropout(dropout)
        else:
            self.dropout = dropout
        self.hidden_size = hidden_size or input_size
        self.fc = nn.Linear(input_size, self.hidden_size * 2)

        self.init_weights()

    def init_weights(self):
        for n, p in self.named_parameters():
            if "bias" in n:
                torch.nn.init.zeros_(p)
            elif "fc" in n:
                torch.nn.init.xavier_uniform_(p)

    def forward(self, x):
        if self.dropout is not None:
            x = self.dropout(x)
        x = self.fc(x)
        x = F.glu(x, dim=-1)
        return x


class _ResampleNorm(nn.Module):
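    """Resamples the input to ``output_size`` when the sizes differ (via
    time-distributed linear interpolation, or a linear layer if
    ``skip_interpolation=True``), optionally applies a trainable sigmoid gate,
    and normalizes the result."""
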
    def __init__(
        self,
        input_size: int,
        output_size: Optional[int] = None,
        trainable_add: bool = True,
        norm: type[nn.Module] = nn.LayerNorm,
        skip_interpolation: bool = False,
    ):
        super().__init__()

        self.input_size = input_size
        self.trainable_add = trainable_add
        self.output_size = output_size or input_size

        if self.input_size != self.output_size:
            if skip_interpolation:
                self.resample = nn.Linear(self.input_size, self.output_size)
            else:
                self.resample = _TimeDistributedInterpolation(
                    self.output_size, batch_first=True, trainable=False
                )

        if self.trainable_add:
            self.mask = nn.Parameter(torch.zeros(self.output_size, dtype=torch.float))
            self.gate = nn.Sigmoid()
        self.norm = norm(self.output_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.input_size != self.output_size:
            x = self.resample(x)

        if self.trainable_add:
            x = x * self.gate(self.mask) * 2.0

        output = self.norm(x)
        return output


class _AddNorm(nn.Module):
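    """Residual add & norm: the skip connection is resampled to ``input_size``
    when the sizes differ (and optionally gated by a trainable sigmoid mask),
    then added to the input and normalized."""
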
    def __init__(
        self,
        input_size: int,
        skip_size: Optional[int] = None,
        trainable_add: bool = True,
        norm: type[nn.Module] = nn.LayerNorm,
        skip_interpolation: bool = False,
    ):
        super().__init__()

        self.input_size = input_size
        self.trainable_add = trainable_add
        self.skip_size = skip_size or input_size

        if self.input_size != self.skip_size:
            if skip_interpolation:
                self.resample = nn.Linear(self.skip_size, self.input_size)
            else:
                self.resample = _TimeDistributedInterpolation(
                    self.input_size, batch_first=True, trainable=False
                )

        if self.trainable_add:
            self.mask = nn.Parameter(torch.zeros(self.input_size, dtype=torch.float))
            self.gate = nn.Sigmoid()
        self.norm = norm(self.input_size)

    def forward(self, x: torch.Tensor, skip: torch.Tensor):
        if self.input_size != self.skip_size:
            skip = self.resample(skip)

        if self.trainable_add:
            skip = skip * self.gate(self.mask) * 2.0

        output = self.norm(x + skip)
        return output


class _GateAddNorm(nn.Module):
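    """Gating followed by a residual add & norm:
    ``GateAddNorm(x, skip) = Norm(GLU(x) + skip)``, with optional resampling and
    gating of the skip connection inside the ``_AddNorm``."""
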
    def __init__(
        self,
        input_size: int,
        hidden_size: Optional[int] = None,
        skip_size: Optional[int] = None,
        trainable_add: bool = False,
        dropout: Optional[float] = None,
        layer_norm: type[nn.Module] = nn.LayerNorm,
        skip_interpolation: bool = False,
    ):
        super().__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size or input_size
        self.skip_size = skip_size or self.hidden_size
        self.dropout = dropout

        self.glu = _GatedLinearUnit(
            self.input_size, hidden_size=self.hidden_size, dropout=self.dropout
        )
        self.add_norm = _AddNorm(
            self.hidden_size,
            skip_size=self.skip_size,
            trainable_add=trainable_add,
            norm=layer_norm,
            skip_interpolation=skip_interpolation,
        )

    def forward(self, x, skip):
        output = self.glu(x)
        output = self.add_norm(output, skip)
        return output


class _GatedResidualNetwork(nn.Module):
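    """Gated Residual Network as described in the TFT paper (Lim et al., 2021):
    ``GRN(a, c) = LayerNorm(a + GLU(W_1 ELU(W_2 a + W_3 c + b_2) + b_1))``,
    where the context vector ``c`` is optional; the skip connection is resampled
    via ``_ResampleNorm`` when input and output sizes differ."""
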
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: int,
        dropout: float = 0.1,
        context_size: Optional[int] = None,
        residual: bool = False,
        layer_norm: type[nn.Module] = nn.LayerNorm,
        skip_interpolation: bool = False,
    ):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.context_size = context_size
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.residual = residual

        if (self.input_size != self.output_size) and (not self.residual):
            self.do_resampling = True
            self.resample_norm = _ResampleNorm(
                input_size=self.input_size,
                output_size=self.output_size,
                norm=layer_norm,
                skip_interpolation=skip_interpolation,
            )
        else:
            self.do_resampling = False

        self.fc1 = nn.Linear(self.input_size, self.hidden_size)
        self.elu = nn.ELU()

        if self.context_size is not None:
            self.context = nn.Linear(self.context_size, self.hidden_size, bias=False)

        self.fc2 = nn.Linear(self.hidden_size, self.hidden_size)
        self.init_weights()

        self.gate_norm = _GateAddNorm(
            input_size=self.hidden_size,
            skip_size=self.output_size,
            hidden_size=self.output_size,
            dropout=self.dropout,
            trainable_add=False,
            skip_interpolation=skip_interpolation,
        )

    def init_weights(self):
        for name, p in self.named_parameters():
            if "bias" in name:
                torch.nn.init.zeros_(p)
            elif "fc1" in name or "fc2" in name:
                torch.nn.init.kaiming_normal_(
                    p, a=0, mode="fan_in", nonlinearity="leaky_relu"
                )
            elif "context" in name:
                torch.nn.init.xavier_uniform_(p)

    def forward(self, x, context=None, residual=None):
        if residual is None:
            residual = x

        if self.do_resampling:
            residual = self.resample_norm(residual)

        x = self.fc1(x)
        if context is not None:
            context = self.context(context)
            x = x + context
        x = self.elu(x)
        x = self.fc2(x)
        x = self.gate_norm(x, residual)
        return x


class _VariableSelectionNetwork(nn.Module):
    def __init__(
        self,
        input_sizes: dict[str, int],
        hidden_size: int,
        input_embedding_flags: Optional[dict[str, bool]] = None,
        dropout: float = 0.1,
        context_size: Optional[int] = None,
        single_variable_grns: Optional[dict[str, _GatedResidualNetwork]] = None,
        prescalers: Optional[dict[str, nn.Linear]] = None,
        layer_norm: type[nn.Module] = nn.LayerNorm,
        skip_interpolation: bool = False,
    ):
        """
        Calculate weights for ``num_inputs`` variables  which are each of size ``input_size``
        """
        super().__init__()

        input_embedding_flags = (
            input_embedding_flags if input_embedding_flags is not None else {}
        )
        single_variable_grns = (
            single_variable_grns if single_variable_grns is not None else {}
        )
        prescalers = prescalers if prescalers is not None else {}

        self.hidden_size = hidden_size
        self.input_sizes = input_sizes
        self.input_embedding_flags = input_embedding_flags
        self.dropout = dropout
        self.context_size = context_size

        if self.num_inputs > 1:
            # ``context_size=None`` simply disables the context projection in
            # the flattened GRN
            self.flattened_grn = _GatedResidualNetwork(
                self.input_size_total,
                min(self.hidden_size, self.num_inputs),
                self.num_inputs,
                self.dropout,
                context_size=self.context_size,
                residual=False,
                skip_interpolation=skip_interpolation,
            )

        self.single_variable_grns = nn.ModuleDict()
        self.prescalers = nn.ModuleDict()
        for name, input_size in self.input_sizes.items():
            if name in single_variable_grns:
                self.single_variable_grns[name] = single_variable_grns[name]
            elif self.input_embedding_flags.get(name, False):
                self.single_variable_grns[name] = _ResampleNorm(
                    input_size=input_size,
                    output_size=self.hidden_size,
                    norm=layer_norm,
                    skip_interpolation=skip_interpolation,
                )
            else:
                self.single_variable_grns[name] = _GatedResidualNetwork(
                    input_size,
                    min(input_size, self.hidden_size),
                    output_size=self.hidden_size,
                    dropout=self.dropout,
                    skip_interpolation=skip_interpolation,
                )
            if name in prescalers:  # real-valued inputs need to be scaled up first
                self.prescalers[name] = prescalers[name]
            elif not self.input_embedding_flags.get(name, False):
                self.prescalers[name] = nn.Linear(1, input_size)

        self.softmax = nn.Softmax(dim=-1)

    @property
    def input_size_total(self):
        return sum(self.input_sizes.values())

    @property
    def num_inputs(self):
        return len(self.input_sizes)

    def forward(
        self, x: dict[str, torch.Tensor], context: Optional[torch.Tensor] = None
    ):
        if self.num_inputs > 1:
            # transform single variables
            var_outputs = []
            weight_inputs = []
            for name in self.input_sizes.keys():
                # select embedding belonging to a single input
                variable_embedding = x[name]
                if name in self.prescalers:
                    variable_embedding = self.prescalers[name](variable_embedding)
                weight_inputs.append(variable_embedding)
                var_outputs.append(self.single_variable_grns[name](variable_embedding))
            var_outputs = torch.stack(var_outputs, dim=-1)

            # calculate variable weights
            flat_embedding = torch.cat(weight_inputs, dim=-1)
            sparse_weights = self.flattened_grn(flat_embedding, context)
            sparse_weights = self.softmax(sparse_weights).unsqueeze(-2)

            outputs = var_outputs * sparse_weights
            outputs = outputs.sum(dim=-1)
        else:  # for one input, do not perform variable selection but just encoding
            name = next(iter(self.single_variable_grns.keys()))
            variable_embedding = x[name]
            if name in self.prescalers:
                variable_embedding = self.prescalers[name](variable_embedding)
            outputs = self.single_variable_grns[name](
                variable_embedding
            )  # fast forward if only one variable
            if outputs.ndim == 3:  # outputs: batch size x time x hidden size
                # weights broadcast over the hidden and variable dimensions
                sparse_weights = torch.ones(
                    outputs.size(0), outputs.size(1), 1, 1, device=outputs.device
                )
            else:  # ndim == 2 -> batch size x hidden size
                sparse_weights = torch.ones(
                    outputs.size(0), 1, 1, device=outputs.device
                )
        return outputs, sparse_weights
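
# Minimal usage sketch for ``_VariableSelectionNetwork`` (hypothetical sizes,
# not part of the module API): three real-valued variables, each entering as a
# scalar per timestep, lifted to size 8 by their prescalers and selected into a
# common hidden space of size 16.
#   vsn = _VariableSelectionNetwork(
#       input_sizes={"a": 8, "b": 8, "c": 8}, hidden_size=16
#   )
#   x = {name: torch.randn(32, 24, 1) for name in ("a", "b", "c")}
#   out, weights = vsn(x)  # out: (32, 24, 16), weights: (32, 24, 1, 3)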


class _ScaledDotProductAttention(nn.Module):
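    """Dot-product attention ``softmax(Q K^T / sqrt(d_k)) V``: positions where
    ``mask`` is True are filled with ``-1e9`` before the softmax, scaling by
    ``sqrt(d_k)`` can be disabled via ``scale``, and Monte Carlo dropout is
    optionally applied to the attention weights."""
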
    def __init__(self, dropout: Optional[float] = None, scale: bool = True):
        super().__init__()
        if dropout is not None:
            self.dropout = MonteCarloDropout(p=dropout)
        else:
            self.dropout = dropout
        self.softmax = nn.Softmax(dim=2)
        self.scale = scale

    def forward(self, q, k, v, mask=None):
        attn = torch.bmm(q, k.permute(0, 2, 1))  # query-key overlap

        if self.scale:
            dimension = torch.sqrt(torch.tensor(k.shape[-1]).to(torch.float32))
            attn = attn / dimension

        if mask is not None:
            attn = attn.masked_fill(mask, -1e9)
        attn = self.softmax(attn)

        if self.dropout is not None:
            attn = self.dropout(attn)
        output = torch.bmm(attn, v)
        return output, attn


class _InterpretableMultiHeadAttention(nn.Module):
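    """Interpretable multi-head attention from the TFT paper: each head has its
    own query and key projections, but all heads share a single value projection
    and output matrix, so the per-head outputs can be averaged into a single
    interpretable attention pattern."""
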
    def __init__(self, n_head: int, d_model: int, dropout: float = 0.0):
        super().__init__()

        self.n_head = n_head
        self.d_model = d_model
        self.d_k = self.d_q = self.d_v = d_model // n_head
        self.dropout = MonteCarloDropout(p=dropout)

        self.v_layer = nn.Linear(self.d_model, self.d_v)
        self.q_layers = nn.ModuleList([
            nn.Linear(self.d_model, self.d_q) for _ in range(self.n_head)
        ])
        self.k_layers = nn.ModuleList([
            nn.Linear(self.d_model, self.d_k) for _ in range(self.n_head)
        ])
        self.attention = _ScaledDotProductAttention()
        self.w_h = nn.Linear(self.d_v, self.d_model, bias=False)

        self.init_weights()

    def init_weights(self):
        for name, p in self.named_parameters():
            if "bias" not in name:
                torch.nn.init.xavier_uniform_(p)
            else:
                torch.nn.init.zeros_(p)

    def forward(self, q, k, v, mask=None) -> tuple[torch.Tensor, torch.Tensor]:
        heads = []
        attns = []
        vs = self.v_layer(v)
        for i in range(self.n_head):
            qs = self.q_layers[i](q)
            ks = self.k_layers[i](k)
            head, attn = self.attention(qs, ks, vs, mask)
            head_dropout = self.dropout(head)
            heads.append(head_dropout)
            attns.append(attn)

        head = torch.stack(heads, dim=2) if self.n_head > 1 else heads[0]
        attn = torch.stack(attns, dim=2)

        outputs = torch.mean(head, dim=2) if self.n_head > 1 else head
        outputs = self.w_h(outputs)
        outputs = self.dropout(outputs)

        return outputs, attn
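
# Minimal usage sketch for ``_InterpretableMultiHeadAttention`` (hypothetical
# sizes, not part of the module API): self-attention of a sequence onto itself.
#   mha = _InterpretableMultiHeadAttention(n_head=4, d_model=16)
#   x = torch.randn(32, 24, 16)
#   out, attn = mha(x, x, x)  # out: (32, 24, 16), attn: (32, 24, 4, 24)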
